# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
from functools import wraps
import contextvars
import contextlib

import towhee.runtime.pipeline_loader as pipe_loader
import towhee.runtime.node_config as nd_conf
from towhee.utils.log import engine_log


# Holds the config name currently being loaded (set via ``set_config_name``);
# ``AutoConfig.register`` consults it through ``get_config_name``.
_AUTO_CONFIG_VAR: contextvars.ContextVar = contextvars.ContextVar('auto_config_var')


@contextlib.contextmanager
def set_config_name(name: str):
    """
    Temporarily set the current config name for the duration of a ``with`` block.

    The previous value is restored on exit even if the body raises, so a failed
    load cannot leave a stale name behind for later registrations.

    Args:
        name (`str`): the config name to expose via ``get_config_name``.
    """
    token = _AUTO_CONFIG_VAR.set(name)
    try:
        yield
    finally:
        # Always undo our set(), restoring whatever value (or unset state) preceded it.
        _AUTO_CONFIG_VAR.reset(token)


def get_config_name():
    """
    Return the config name set by the innermost active ``set_config_name``.

    Returns:
        The current config name, or None if no name is set in this context.
    """
    # ContextVar.get(default) returns the default instead of raising LookupError.
    return _AUTO_CONFIG_VAR.get(None)


# pylint: disable=invalid-name
class AutoConfig:

    """
    Auto configuration.

    A namespace (it cannot be instantiated) providing factories for node
    configurations (``LocalCPUConfig``, ``LocalGPUConfig``, ``TritonCPUConfig``,
    ``TritonGPUConfig``) and a by-name registry of config factories that can be
    retrieved with ``load_config``.
    """

    # Maps config name -> callable that builds and returns a config.
    _REGISTERED_CONFIG = {}
    # Guards registry lookups and pipeline loading inside load_config.
    _lock = threading.Lock()

    def __init__(self):
        raise EnvironmentError(
            'AutoConfig is not designed to be instantiated, please use `AutoConfig.LocalCPUConfig()` etc.'
        )

    @staticmethod
    def register(config):
        """
        Decorator that registers ``config`` as a named config factory.

        The factory is stored under the name returned by ``get_config_name()``,
        i.e. the name set by an enclosing ``set_config_name`` (``load_config``
        sets it while loading a pipeline). If no name is set, nothing is
        registered and the decorator only wraps ``config``.

        Args:
            config (`callable`): a callable that builds and returns a config.

        Returns:
            A wrapper around ``config`` that forwards all arguments to it.
        """
        @wraps(config)
        def wrapper(*args, **kwargs):
            return config(*args, **kwargs)

        name = get_config_name()
        if name is not None:
            AutoConfig._REGISTERED_CONFIG[name] = wrapper
        return wrapper

    @staticmethod
    def load_config(name: str, *args, **kwargs):
        """
        Load config from pre-defined pipeline.

        If no factory is registered under ``name`` yet, the pipeline of that name
        is loaded so that it can register its config via ``AutoConfig.register``.

        Args:
            name (`str`): the name of the config / pre-defined pipeline.
            *args, **kwargs: forwarded to the config factory.

        Returns:
            The config built by the registered factory, or None (after logging an
            error) if no config named ``name`` can be found.

        Examples:
            >>> from towhee import AutoConfig
            >>> config = AutoConfig.load_config('sentence_embedding')
            >>> config
            SentenceSimilarityConfig(model='all-MiniLM-L6-v2', openai_api_key=None, customize_embedding_op=None, normalize_vec=True, device=-1)
        """
        # Only the registry lookup and the (one-time) pipeline load need the lock.
        # The factory is invoked after releasing it, so a slow factory does not
        # serialize other callers and a factory calling load_config cannot
        # deadlock on the non-reentrant lock.
        with AutoConfig._lock:
            factory = AutoConfig._REGISTERED_CONFIG.get(name)
            if factory is None:
                with set_config_name(name):
                    pipe_loader.PipelineLoader.load_pipeline(name)
                factory = AutoConfig._REGISTERED_CONFIG.get(name)

        if factory is None:
            engine_log.error('Can not find config: %s', name)
            return None
        return factory(*args, **kwargs)

    @staticmethod
    def LocalCPUConfig():
        """
        Auto configuration to run with local CPU.

        Returns:
            The config produced by ``TowheeConfig.set_local_config(device=-1)``.

        Examples:
            >>> from towhee import pipe, AutoConfig
            >>> p = (pipe.input('a')
            ...          .flat_map('a', 'b', lambda x: [y for y in x], config=AutoConfig.LocalCPUConfig())
            ...          .output('b'))
        """
        return nd_conf.TowheeConfig.set_local_config(device=-1)

    @staticmethod
    def LocalGPUConfig(device: int = 0):
        """
        Auto configuration to run with local GPU.

        Args:
            device (`int`): the index of the GPU device to use, defaults to 0.

        Returns:
            The config produced by ``TowheeConfig.set_local_config(device=device)``.

        Examples:
            >>> from towhee import pipe, ops, AutoConfig
            >>> p = (pipe.input('url')
            ...          .map('url', 'image', ops.image_decode.cv2())
            ...          .map('image', 'vec', ops.image_embedding.timm(model_name='resnet50'), config=AutoConfig.LocalGPUConfig())
            ...          .output('vec')
            ... )
        """
        return nd_conf.TowheeConfig.set_local_config(device=device)

    @staticmethod
    def TritonCPUConfig(num_instances_per_device: int = 1,
                        max_batch_size: int = None,
                        batch_latency_micros: int = None,
                        preferred_batch_size: list = None):
        """
        Auto configuration to run with triton server (CPU).

        Args:
            num_instances_per_device(`int`):
                the number of instances per device, defaults to 1.
            max_batch_size(`int`):
                maximum batch size, defaults to None, and it will be auto-generated by triton.
            batch_latency_micros(`int`):
                batching latency, in microseconds, defaults to None, and it will be auto-generated by triton.
            preferred_batch_size(`list`):
                preferred batch sizes for dynamic batching, defaults to None, and it will be auto-generated by triton.

        Returns:
            The config produced by ``TowheeConfig.set_triton_config`` with ``device_ids=None``.

        Examples:
            >>> from towhee import pipe, ops, AutoConfig
            >>> p = (pipe.input('url')
            ...          .map('url', 'image', ops.image_decode.cv2())
            ...          .map('image', 'vec', ops.image_embedding.timm(model_name='resnet50'), config=AutoConfig.TritonCPUConfig())
            ...          .output('vec')
            ... )

            You can also customize the configuration:

            >>> from towhee import pipe, ops, AutoConfig
            >>> config = AutoConfig.TritonCPUConfig(num_instances_per_device=3,
            ...                                     max_batch_size=128,
            ...                                     batch_latency_micros=100000,
            ...                                     preferred_batch_size=[8, 16])
            >>> p = (pipe.input('url')
            ...          .map('url', 'image', ops.image_decode.cv2())
            ...          .map('image', 'vec', ops.image_embedding.timm(model_name='resnet50'), config=config)
            ...          .output('vec')
            ... )
        """
        return nd_conf.TowheeConfig.set_triton_config(device_ids=None,
                                                      num_instances_per_device=num_instances_per_device,
                                                      max_batch_size=max_batch_size,
                                                      batch_latency_micros=batch_latency_micros,
                                                      preferred_batch_size=preferred_batch_size)

    @staticmethod
    def TritonGPUConfig(device_ids: list = None,
                        num_instances_per_device: int = 1,
                        max_batch_size: int = None,
                        batch_latency_micros: int = None,
                        preferred_batch_size: list = None):
        """
        Auto configuration to run with triton server (GPUs).

        Args:
            device_ids(`list`):
                list of GPU device ids, defaults to None, which means [0].
            num_instances_per_device(`int`):
                the number of instances per device, defaults to 1.
            max_batch_size(`int`):
                maximum batch size, defaults to None, and it will be auto-generated by triton.
            batch_latency_micros(`int`):
                batching latency, in microseconds, defaults to None, and it will be auto-generated by triton.
            preferred_batch_size(`list`):
                preferred batch sizes for dynamic batching, defaults to None, and it will be auto-generated by triton.

        Returns:
            The config produced by ``TowheeConfig.set_triton_config``.

        Examples:
            >>> from towhee import pipe, ops, AutoConfig
            >>> p = (pipe.input('url')
            ...          .map('url', 'image', ops.image_decode.cv2())
            ...          .map('image', 'vec', ops.image_embedding.timm(model_name='resnet50'), config=AutoConfig.TritonGPUConfig())
            ...          .output('vec')
            ... )

            You can also customize the configuration:

            >>> from towhee import pipe, ops, AutoConfig
            >>> config = AutoConfig.TritonGPUConfig(device_ids=[0, 1],
            ...                                     num_instances_per_device=3,
            ...                                     max_batch_size=128,
            ...                                     batch_latency_micros=100000,
            ...                                     preferred_batch_size=[8, 16])
            >>> p = (pipe.input('url')
            ...          .map('url', 'image', ops.image_decode.cv2())
            ...          .map('image', 'vec', ops.image_embedding.timm(model_name='resnet50'), config=config)
            ...          .output('vec')
            ... )
        """
        # None sentinel instead of a mutable [0] default shared across calls.
        if device_ids is None:
            device_ids = [0]
        return nd_conf.TowheeConfig.set_triton_config(device_ids=device_ids,
                                                      num_instances_per_device=num_instances_per_device,
                                                      max_batch_size=max_batch_size,
                                                      batch_latency_micros=batch_latency_micros,
                                                      preferred_batch_size=preferred_batch_size)


# Make the built-in configuration factories reachable by name through
# ``AutoConfig.load_config``, alongside configs registered by pipelines.
AutoConfig._REGISTERED_CONFIG.update(  # pylint: disable=protected-access
    {
        builtin_name: getattr(AutoConfig, builtin_name)
        for builtin_name in ('LocalCPUConfig', 'LocalGPUConfig', 'TritonCPUConfig', 'TritonGPUConfig')
    }
)
